from functools import cached_property
from typing import Any, Literal, TypedDict

from annotated_types import Ge, Le
from pydantic import BaseModel, Field, NonNegativeInt, RootModel, computed_field, model_validator
from typing_extensions import Annotated


class TokenUsage(BaseModel):
    # Aggregated token-count, latency and masking statistics for one or more completions.
    # Records from successive calls can be merged with ``+`` (see ``__add__``).

    input_tokens: NonNegativeInt = 0
    """Number of tokens used as input to the model, inclusive of cached_input_tokens."""

    cached_input_tokens: NonNegativeInt = 0
    """Number of input tokens that were already in the KV cache."""

    forward_passes: NonNegativeInt = 0
    """Number of forward passes made to the model."""

    cached_output_tokens: NonNegativeInt = 0
    """Number of forward passes we avoided by hitting the KV cache."""

    ff_tokens: NonNegativeInt | None = None
    """Number of output tokens that were fast-forwarded by the parser (if applicable)."""

    round_trips: NonNegativeInt = 0
    """Number of times a completion was generated. For remote models, this is the number of
    API calls made to the model. For local models, this is the number of times we entered a
    completion loop to generate output tokens."""

    total_latency_ms: Annotated[float, Ge(0)] = 0.0
    """Total latency of the model in milliseconds. This includes the time spent on all forward passes and
    the time spent on fast-forwarding tokens (if applicable)."""

    ttft_ms: Annotated[float, Ge(0)] = 0.0
    """Time to first token in ms"""

    mask_times_ms: list[Annotated[float, Ge(0)]] = Field(default_factory=list)
    """List of mask times in ms for each token generated."""

    mask_overheads_ms: list[Annotated[float, Ge(0)]] = Field(default_factory=list)
    """List of mask overhead times in ms for each token generated."""

    ttfm_ms: Annotated[float, Ge(0)] = 0.0
    """Time to first mask in ms"""

    @computed_field  # type: ignore[misc]
    @property
    def output_tokens(self) -> NonNegativeInt:
        """Total number of output tokens generated by the model, inclusive of cached_output_tokens and ff_tokens.

        Note that this may overcount the actual number of tokens generated if some tokens were backtracked or some
        forward passes were used just to pre-fill or compute probabilities for visualization"""
        return self.forward_passes + self.cached_output_tokens + (self.ff_tokens or 0)

    @computed_field  # type: ignore[misc]
    @property
    def token_savings(self) -> Annotated[float, Ge(0), Le(1)] | None:
        """The fraction of output tokens that were fast-forwarded by the parser (if applicable)."""
        if self.ff_tokens is None:
            return None
        if self.output_tokens == 0:
            return 0.0
        return self.ff_tokens / self.output_tokens

    @computed_field  # type: ignore[misc]
    @property
    def avg_latency_ms(self) -> float:
        """Average latency of tokens generated by the model."""
        if self.output_tokens == 0:
            return 0.0
        return self.total_latency_ms / self.output_tokens

    def __add__(self, other: "TokenUsage") -> "TokenUsage":
        """Combine two usage records into a new ``TokenUsage``.

        * ``ff_tokens`` stays ``None`` only when it is ``None`` on both operands; otherwise
          the two values are summed, treating ``None`` as 0.
        * ``ttft_ms`` and ``ttfm_ms`` are not summed: the right-hand operand's values are kept.
        * Every other declared field is combined with ``+``. Numeric counters and latencies
          are summed, and the per-token lists are concatenated with ``self``'s entries first.

        Returns ``NotImplemented`` when ``other`` is not a ``TokenUsage``, so Python raises
        the usual ``TypeError`` instead of an ``AttributeError`` on a missing field.
        """
        if not isinstance(other, TokenUsage):
            return NotImplemented

        if self.ff_tokens is None and other.ff_tokens is None:
            ff_tokens = None
        else:
            ff_tokens = (self.ff_tokens or 0) + (other.ff_tokens or 0)

        # NOTE(review): for a running total ``acc = acc + new`` this reports the latest
        # call's ttft/ttfm rather than the first one's. Confirm that this is intended.
        ttft_ms = other.ttft_ms
        ttfm_ms = other.ttfm_ms

        return TokenUsage(
            ff_tokens=ff_tokens,
            ttft_ms=ttft_ms,
            ttfm_ms=ttfm_ms,
            **{
                field: getattr(self, field) + getattr(other, field)
                for field in set(self.__class__.model_fields) - {"ff_tokens", "ttft_ms", "ttfm_ms"}
            },
        )


class EngineResponse(BaseModel):
    # Result of one engine step: newly produced bytes, backtracking information and capture data.
    new_bytes: bytes
    backtrack_bytes: bytes  # presumably the bytes undone by backtracking; confirm against the producer
    capture_groups: dict  # capture name -> captured value
    capture_group_log_probs: dict  # capture name -> log-prob(s), keyed like capture_groups
    backtrack: NonNegativeInt = 0  # number of tokens that were backtracked by the parser
    # Pydantic copies the ``[]`` default for each instance, so the mutable default is not shared.
    tokens: list["GenToken"] = []  # tokens associated with the generated bytes


class LegacyEngineCallResponse(BaseModel):
    # Flattened result of one engine round. LLProgress.to_engine_call_response builds one
    # of these from parser progress items.
    new_bytes: bytes
    is_generated: bool
    # NOTE(review): LLProgress.to_engine_call_response fills this with the *mean log_prob*
    # of the text entries, not a probability. Confirm consumers expect that.
    new_bytes_prob: float
    capture_groups: dict  # capture name -> bytes, or list of bytes for list-append captures
    capture_group_log_probs: dict  # capture name -> log_prob (or list), mirroring capture_groups
    new_token_count: NonNegativeInt
    backtrack: NonNegativeInt = 0  # number of tokens that were backtracked by the parser
    # Time taken by the engine to generate the output chunk, in ms. This is a non-negative
    # float, consistent with GenToken.latency_ms and TokenUsage's *_ms fields. Fractional
    # timings therefore validate, where an int field would reject them; int inputs are
    # still accepted.
    latency_ms: Annotated[float, Ge(0)] = 0.0
    generated_bytes: bytes = b""  # bytes generated by the engine
    generated_tokens: list["GenToken"] = []  # tokens associated with the generated bytes
    force_forwarded_bytes: bytes = b""  # bytes that were forced forwards by the parser
    force_forwarded_tokens: list["GenToken"] = []  # tokens associated with the force-forwarded bytes


class GenToken(BaseModel):
    # A single token together with its probability, timing and provenance flags.
    token_id: int
    bytes: bytes  # byte content of the token
    prob: float = float("nan")  # NaN by default, presumably meaning "probability unknown"; confirm
    latency_ms: float = 0.0  # presumably time spent producing this token, in ms; confirm
    is_masked: bool = False  # true if this token was ignored by the parser
    is_generated: bool = False  # true if this token was generated by the engine
    is_force_forwarded: bool = False  # true if this token was forced forwarded by the parser
    is_input: bool = False  # true if this token was part of the input
    is_backtracked: bool = False  # true if this token was backtracked by the parser


class GenTokenExtra(GenToken):
    # A GenToken plus a list of alternative candidate tokens. These are presumably the top-k
    # candidates at the same position; confirm with the producer.
    top_k: list["GenToken"] = []  # pydantic copies this default per instance


class GenData(BaseModel):
    # Inputs for one sampling step.
    # - tokens: a list of token ids (presumably the context so far; confirm with the producer).
    # - mask: one byte per token id. A non-zero byte marks that id as an allowed next token.
    # - temperature: the sampling temperature.
    tokens: list[int]
    mask: bytes
    temperature: float

    @computed_field  # type: ignore[misc]
    @cached_property
    def valid_next_tokens(self) -> list[int]:
        # Token ids whose mask byte is non-zero. This is computed lazily and then cached on
        # the instance, so the full mask is scanned at most once.
        return [token_id for token_id, allowed in enumerate(self.mask) if allowed]


class LLProgressCapture(BaseModel):
    # A named capture reported by the parser. ``hex`` holds the captured bytes, hex-encoded.
    object: Literal["capture"]
    name: str
    hex: str
    log_prob: float
    list_append: bool = False  # True when the value should be appended to a list capture

    @model_validator(mode="before")
    @classmethod
    def strip_list_append_prefix(cls, values: Any) -> Any:
        """Translate a ``__LIST_APPEND:``-prefixed capture name into ``list_append=True``.

        If the raw payload's ``name`` starts with ``__LIST_APPEND:``, the prefix is removed
        and ``list_append`` is forced to ``True``, overriding any supplied value. The caller's
        mapping is never mutated; a modified copy is returned instead.

        Non-dict inputs, and payloads without a string ``name``, are returned unchanged. Field
        validation then reports the problem as a ``ValidationError``, instead of this hook
        raising ``KeyError``/``AttributeError``.
        """
        prefix = "__LIST_APPEND:"
        if not isinstance(values, dict):
            return values
        name = values.get("name")
        if isinstance(name, str) and name.startswith(prefix):
            # Override whatever was set
            return {**values, "name": name[len(prefix) :], "list_append": True}
        return values


class LLProgressText(BaseModel):
    # A chunk of output text reported by the parser.
    object: Literal["text"]
    hex: str  # the text's bytes, hex-encoded
    num_tokens: NonNegativeInt  # number of tokens covered by this chunk
    log_prob: float
    is_generated: bool  # presumably True when the model produced this text rather than it being forced; confirm


class LLProgressFinalText(BaseModel):
    # Progress entry tagged object == "final_text". Only the tag is modelled: other keys in
    # the payload are dropped under pydantic's default extra="ignore".
    object: Literal["final_text"]
    # we don't need to handle this for now


# One progress entry. Pydantic dispatches on the ``object`` field to pick the concrete model:
# "capture" -> LLProgressCapture, "text" -> LLProgressText, "final_text" -> LLProgressFinalText.
LLProgressItem = Annotated[
    LLProgressCapture | LLProgressText | LLProgressFinalText,
    Field(discriminator="object"),
]


class LLProgress(RootModel):
    root: list[LLProgressItem]

    def to_engine_call_response(self) -> LegacyEngineCallResponse:
        """Fold one round of progress items into a ``LegacyEngineCallResponse``.

        Text items contribute as follows:
        * their hex payloads are decoded and concatenated into ``new_bytes``;
        * their ``num_tokens`` are summed into ``new_token_count``;
        * the mean of their ``log_prob`` values becomes ``new_bytes_prob`` (0.0 when there
          are none);
        * any generated text item makes ``is_generated`` True.

        Capture items fill ``capture_groups`` / ``capture_group_log_probs`` and always set
        ``is_generated`` to True. A plain capture overwrites the named entry. A list-append
        capture appends to a list, starting a fresh list if the entry is missing or is
        currently a scalar. Final-text items are ignored.
        """
        text_chunks: list[bytes] = []
        token_total = 0
        log_prob_total = 0.0
        text_count = 0
        generated = False
        groups: dict[str, Any] = {}
        group_log_probs: dict[str, Any] = {}

        for item in self.root:
            if isinstance(item, LLProgressText):
                # it actually should only happen once per round...
                text_chunks.append(bytes.fromhex(item.hex))
                token_total += item.num_tokens
                log_prob_total += item.log_prob
                generated = generated or item.is_generated
                text_count += 1
                continue

            if not isinstance(item, LLProgressCapture):
                continue

            generated = True
            key = item.name
            payload = bytes.fromhex(item.hex)
            if not item.list_append:
                groups[key] = payload
                group_log_probs[key] = item.log_prob
                continue

            if not isinstance(groups.get(key), list):
                groups[key] = []
                group_log_probs[key] = []
            groups[key].append(payload)
            group_log_probs[key].append(item.log_prob)

        mean_log_prob = log_prob_total / text_count if text_count > 0 else log_prob_total

        return LegacyEngineCallResponse(
            new_bytes=b"".join(text_chunks),
            new_token_count=token_total,
            new_bytes_prob=mean_log_prob,
            is_generated=generated,
            capture_groups=groups,
            capture_group_log_probs=group_log_probs,
        )


class LLInterpreterResponse(BaseModel):
    # One response from the grammar interpreter: the progress items for this step plus control fields.
    progress: LLProgress
    stop: bool  # presumably True when generation should stop; confirm against the producer
    temperature: float | None  # required key, but the value may be None


class SamplingParams(TypedDict):
    # Sampling knobs. The TypedDict is total, so every key must be present, but each value
    # may be None. None presumably means "use the engine default"; confirm with callers.
    top_p: float | None
    top_k: int | None
    min_p: float | None
    repetition_penalty: float | None
